1667B - Optimal Partition - CodeForces Solution


data structures dp

Please click on ads to support us.

Python Code:

import io,os
import bisect
# Fast input: read os.fstat(0).st_size bytes from file descriptor 0 in one
# call and shadow the builtin ``input`` with the in-memory buffer's
# ``readline``.  It returns *bytes* (including the trailing newline); int()
# and bytes.split() accept that directly.
# NOTE(review): st_size is only meaningful when stdin is a regular file
# (redirected input); on a pipe it may be 0 — confirm how input is supplied.
input = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline


class segment_max(object):
    """Bottom-up (iterative) max segment tree over ``n`` leaves.

    Leaf ``k`` (0-based) is stored at ``arr[n + k]``; an internal node ``p``
    aggregates nodes ``2p`` and ``2p + 1``.  Values never decrease:
    :meth:`update` only raises a node to *at least* the given target.
    """

    def merge(self, num, minimum):
        """Combine two node values by taking their maximum."""
        return max(minimum, num)

    def __init__(self, n):
        self.n = n
        # Every node starts at a large negative sentinel ("no value yet").
        self.arr = [-2147483647] * (2 * n)

    def update(self, index, target):
        """Raise ``arr[index]`` to at least ``target`` and propagate upward.

        ``index`` is a raw position in ``arr`` (callers pass ``leaf + n``).
        Nothing happens if the node already holds a value >= ``target``.
        Propagation stops early at the first ancestor that is already large
        enough, since everything above it is at least as large.
        """
        arr = self.arr
        while arr[index] < target:
            arr[index] = target
            if index == 0:
                break
            # Parent value = max of this node and its sibling (index ^ 1).
            target = self.merge(target, arr[index ^ 1])
            index >>= 1

    def addnum(self, index, diff):
        """Add ``diff`` to leaf ``index`` (0-based leaf position).

        Only has an effect for ``diff > 0``, because :meth:`update` never
        lowers a stored value.
        """
        self.update(index + self.n, self.arr[index + self.n] + diff)

    def query(self, left, right):
        """Return the maximum over leaves ``left..right`` (inclusive).

        An empty range (``left > right``) returns -2147483648; a non-empty
        range whose leaves were never updated returns the initial sentinel
        -2147483647.
        """
        arr = self.arr
        i, j = self.n + left, self.n + right + 1
        output = -2147483648
        while i < j:
            if i & 1:
                output = self.merge(arr[i], output)
                i += 1
            if j & 1:
                j -= 1
                output = self.merge(arr[j], output)
            i >>= 1
            j >>= 1
        return output





def main(t):
    """Read one test case of CF 1667B "Optimal Partition" and print its answer.

    A subarray is worth +len if its sum is positive, 0 if its sum is zero and
    -len if its sum is negative; the answer is the maximum total worth over
    all partitions of the array.

    With ``accu[i]`` the prefix sum ending at index ``i`` and ``dp`` the best
    value of a prefix, extending from an earlier cut ``j`` gives
        dp[j] - j + i   if accu[j] < accu[i]   (last segment sum > 0)
        dp[j]           if accu[j] == accu[i]  (last segment sum == 0)
        dp[j] + j - i   if accu[j] > accu[i]   (last segment sum < 0)
    The first and last cases are max-queries over segment trees indexed by the
    compressed prefix sum; the middle case is a dict keyed by the prefix sum.

    Args:
        t: 1-based test-case number (unused; kept for the driver loop).
    """
    n = int(input())
    arr = list(map(int, input().split()))

    # accu[i] = arr[0] + ... + arr[i].
    accu = []
    running = 0
    for value in arr:
        running += value
        accu.append(running)

    # Coordinate-compress every prefix sum, including the empty prefix (0).
    cor = {value: rank for rank, value in enumerate(sorted(set(accu) | {0}))}

    neg_inf = -2147483648
    size = n + 1  # at most n + 1 distinct prefix sums
    seg_min = segment_max(size)  # holds dp[j] - j; queried over smaller sums
    seg_max = segment_max(size)  # holds dp[j] + j; queried over larger sums

    # The empty prefix is the cut j = -1 with dp = 0, so dp - j = 1 and
    # dp + j = -1.
    zero_rank = cor[0]
    seg_min.update(zero_rank + size, 1)
    seg_max.update(zero_rank + size, -1)

    # Best dp seen so far for each prefix-sum value (the "sum == 0" case).
    maxdp = {0: 0}

    num = 0  # answer for an empty array; overwritten whenever n >= 1
    for i, prefix in enumerate(accu):
        index = cor[prefix]
        num1 = seg_min.query(0, index - 1) + i  # last segment sum > 0
        num2 = seg_max.query(index + 1, n) - i  # last segment sum < 0
        num3 = maxdp.get(prefix, neg_inf)       # last segment sum == 0
        num = max(num1, num2, num3)
        # num >= num3, so this equals max(previous maxdp[prefix], num).
        maxdp[prefix] = num
        seg_min.update(index + size, num - i)
        seg_max.update(index + size, num + i)

    print(num)

        
        
        






























# Driver: read the number of test cases, then solve each one in order,
# passing its 1-based case number.
T = int(input())
for t in range(1, T + 1):
    main(t)

C++ Code:

#include <bits/stdc++.h>


// Shorthand type names and member/method aliases.
#define ll long long
#define lld long double
#define ff first
#define ss second
#define pb push_back
// vr(v) / rv(v): full forward / reverse iterator range of a container.
#define vr(v) v.begin(),v.end()
#define rv(v) v.rbegin(),v.rend()
// Writing `Code By Davit` expands to three fast-I/O statements: unsync C++
// streams from C stdio and untie cin/cout from other streams.
// NOTE(review): they only take effect where that phrase is actually written
// (e.g. at the start of main).
#define Code ios_base::sync_with_stdio(false);
#define By cin.tie(NULL);
#define Davit cout.tie(NULL);

using namespace std;

//#include "algo/debug.h"

// One segment-tree slot holding two independent maxima. solve() fills them
// with {dp[i] - i, dp[i] + i}: `pos` serves transitions whose new segment
// has a positive sum, `neg` those with a negative sum.
struct node {
    ll pos;
    ll neg;
};

// Segment-tree storage (root at index 0) and its leaf count; solve()
// rebuilds both for each test case.
vector<node> seg;
int sz = 1;

// Element-wise maximum of two slots.
node merge(node a, node b) {
    node best;
    best.pos = a.pos > b.pos ? a.pos : b.pos;
    best.neg = a.neg > b.neg ? a.neg : b.neg;
    return best;
}


// Point update: fold `value` into leaf `index` (element-wise max) within the
// subtree rooted at node x, which spans leaves [l, r]; then refresh x from
// its two children, stored at 2x+1 and 2x+2.
void modify(int l, int r, int x, int index, node value) {
    if (l == r) {
        seg[x] = merge(seg[x], value);
        return;
    }
    const int mid = (l + r) / 2;
    const int lc = 2 * x + 1;
    const int rc = 2 * x + 2;
    if (index <= mid) {
        modify(l, mid, lc, index, value);
    } else {
        modify(mid + 1, r, rc, index, value);
    }
    seg[x] = merge(seg[lc], seg[rc]);
}

// Update leaf `index` starting from the root, which spans leaves [0, sz - 1].
void modify(int index, node value) {
    modify(0, sz - 1, 0, index, value);
}


// Range maximum over leaves [lx, rx], restricted to the subtree rooted at
// node x that spans leaves [l, r]. A range disjoint from [l, r] (including
// an empty one, lx > rx) yields the {-1e9, -1e9} sentinel.
node get(int l, int r, int x, int lx, int rx) {
    // No overlap: contribute only the sentinel.
    if (r < lx || l > rx) return {(ll) -1e9, (ll) -1e9};
    // Fully covered: the stored aggregate is the answer for this subtree.
    if (lx <= l && r <= rx) return seg[x];
    const int mid = (l + r) >> 1;
    return merge(get(l, mid, 2 * x + 1, lx, rx),
                 get(mid + 1, r, 2 * x + 2, lx, rx));
}

// Query leaves [l, r] over the whole tree (root 0 spans [0, sz - 1]).
node get(int l, int r) {
    return get(0, sz - 1, 0, l, r);
}


void solve() {
    int n;
    cin >> n;
    vector<ll> v(n);
    for (int i = 0; i < n; i++)cin >> v[i];
    vector<ll> pref(n + 1);
    for (int i = 1; i <= n; i++) {
        pref[i] = v[i - 1] + pref[i - 1];
    }
    vector<ll> compress = pref;
    sort(vr(compress));
    compress.resize(unique(vr(compress)) - compress.begin());
    for (int i = 0; i <= n; i++) {
        pref[i] = lower_bound(vr(compress), pref[i]) - compress.begin();
    }

    sz = (int) compress.size();
    int SZ = 1;
    while (SZ < sz)SZ <<= 1;
    sz = SZ;
    seg = vector<node>(sz + sz, {(ll) -1e9, (ll) -1e9});
    vector<ll> dp(n + 1, -1e9);
    vector<ll> zero(sz, -1e9);
    dp[0] = 0;
    for (int i = 0; i <= n; i++) {
        int index = (int) pref[i];
        if (i) {
            dp[i] = zero[index];
            dp[i] = max(dp[i], get(0, index - 1).pos + i);
            dp[i] = max(dp[i], get(index + 1, sz - 1).neg - i);
        }
        zero[pref[i]] = max(zero[index], dp[i]);
        modify(index, {dp[i] - i, dp[i] + i});
    }
    cout << dp[n] << endl;


}

int main() {
    int t;
    cin >> t;
    while (t--)solve();
}


Comments

Submit
0 Comments
More Questions

1302. Deepest Leaves Sum
1209. Remove All Adjacent Duplicates in String II
994. Rotting Oranges
983. Minimum Cost For Tickets
973. K Closest Points to Origin
969. Pancake Sorting
967. Numbers With Same Consecutive Differences
957. Prison Cells After N Days
946. Validate Stack Sequences
921. Minimum Add to Make Parentheses Valid
881. Boats to Save People
497. Random Point in Non-overlapping Rectangles
528. Random Pick with Weight
470. Implement Rand10() Using Rand7()
866. Prime Palindrome
1516A - Tit for Tat
622. Design Circular Queue
814. Binary Tree Pruning
791. Custom Sort String
787. Cheapest Flights Within K Stops
779. K-th Symbol in Grammar
701. Insert into a Binary Search Tree
429. N-ary Tree Level Order Traversal
739. Daily Temperatures
647. Palindromic Substrings
583. Delete Operation for Two Strings
518. Coin Change 2
516. Longest Palindromic Subsequence
468. Validate IP Address
450. Delete Node in a BST